import torch
import torch.optim as optim
from torch import nn
from torch.nn.utils import clip_grad_norm_
from tqdm import tqdm
import time

from algorithm.JOBCD.Parallel_updateX import Parallel_updateV
from UHstruct.fobj_val import UltraE_fval_obj
from UHstruct.get_hit_ranks import UltraE_test_hits_rank
from dataset.gettriple import Corrupt


class JJOBCD(nn.Module):
    """Joint trainer for entity, bias and relation parameters.

    Entity and bias parameters are updated with Adam (inside ``fgrad``), while
    each relation matrix ``vec_relation[i]`` is updated by ``Parallel_updateV``
    (inside ``innertrain``) using the relation gradient returned by ``fgrad``.
    """

    def __init__(self, triple, ttest, vec_entity, vec_relation, vec_bias, config_yaml):
        """Store data, wrap the initial values as parameters and read the config.

        Args:
            triple: Training triples, passed through to ``Corrupt`` and
                ``UltraE_fval_obj``.
            ttest: Test triples, passed through to ``UltraE_test_hits_rank``.
            vec_entity: Initial entity tensor; its first dimension is taken as
                the number of entities.
            vec_relation: Initial relation tensor, indexed as ``[i, :, :]``;
                its first dimension is taken as the number of relations.
            vec_bias: Initial bias tensor.
            config_yaml: Nested mapping providing ``device`` and the ``run``,
                ``datafeature`` and ``JOBCD`` sections read below.
        """
        super().__init__()
        self.device = config_yaml["device"]
        self.ttest = ttest
        self.config_yaml = config_yaml

        run_cfg = config_yaml["run"]
        data_cfg = config_yaml["datafeature"]
        jobcd_cfg = config_yaml["JOBCD"]

        # Iteration counts go through float() first so values such as "1e3" parse.
        self.maxiter = int(float(run_cfg["maxiter"]))
        self.Rmaxiter = int(float(run_cfg["Rmaxiter"]))
        self.Cmaxiter = int(float(run_cfg["Cmaxiter"]))
        self.d = data_cfg["d"]
        self.p = data_cfg["p"]
        self.Lconstant = float(jobcd_cfg["Lconstant"])
        self.theta = torch.tensor(int(float(jobcd_cfg["theta"])))
        self.margin = data_cfg["margin"]
        self.knum = data_cfg["knum"]
        self.stopindex = 0

        self.triple = triple
        self.Corrupted_triple = None
        self.vec_entity = nn.Parameter(data=vec_entity)
        self.vec_relation = nn.Parameter(data=vec_relation)
        self.vec_bias = nn.Parameter(data=vec_bias)
        self.Rnum = vec_relation.shape[0]
        self.entity_num = self.vec_entity.shape[0]

        # Diagonal signature matrix: +1 on the first p entries, -1 on the other d - p.
        J = torch.eye(self.d)
        J[self.p:, self.p:] = -1 * torch.eye(self.d - self.p)
        self.J = J.to(torch.float32).to(self.device)
        self.lambdaa = 0
        self.myeps = 1e-8
        self.Xgrad = None
        # Adam only owns entity and bias; relations are updated in innertrain().
        self.optimizer = optim.Adam(
            [self.vec_entity, self.vec_bias],
            lr=float(run_cfg["lr_eb"]),
            weight_decay=float(run_cfg["weight_decay"]),
        )

    def Train(self):
        """Run ``maxiter`` outer iterations and record per-iteration metrics.

        Every outer iteration draws fresh corrupted triples, performs
        ``Cmaxiter`` inner steps (``fgrad`` followed by ``innertrain``) and
        then logs the objective, elapsed time, hits and MRR.

        Returns:
            Tuple ``(hist_CObj, vec_relation, hist_t, hist_hits, hist_MRR)``:
            the 1-D cumulative objective (length ``maxiter``), the relation
            parameter, elapsed seconds ``(maxiter, 1)``, hits ``(maxiter, 3)``
            and MRR ``(maxiter, 1)``.
        """
        hist_Obj = torch.zeros(self.maxiter, 1).to(self.device)
        hist_CObj = torch.zeros(self.maxiter, 1).to(self.device)
        hist_t = torch.zeros(self.maxiter, 1).to(self.device)
        hist_hits = torch.zeros(self.maxiter, 3).to(self.device)
        hist_MRR = torch.zeros(self.maxiter, 1).to(self.device)
        start_time = time.time()
        for outer in tqdm(range(self.maxiter), desc='Outiter'):
            # get sample
            self.Corrupted_triple = Corrupt(self.triple, self.entity_num, self.knum)

            # train
            for Citer in tqdm(range(self.Cmaxiter), position=0):
                if Citer % self.config_yaml['run']['inner_dispgap'] == 0:
                    with torch.no_grad():
                        temp = UltraE_fval_obj(self.triple, self.Corrupted_triple, self.vec_entity,
                                               self.vec_relation, self.vec_bias, self.config_yaml)
                        print('JJOBCD  Before R iter:{},citer:{}/{}, obj:{:.4f}'.format(
                            outer, Citer, self.Cmaxiter, temp))
                self.Xgrad = self.fgrad()
                self.innertrain(Citer)

            # log
            with torch.no_grad():
                hist_t[outer] = time.time() - start_time
                hist_Obj[outer] = UltraE_fval_obj(self.triple, self.Corrupted_triple,
                                                  self.vec_entity, self.vec_relation, self.vec_bias,
                                                  self.config_yaml)
                if outer == 0:
                    hist_CObj[outer] = hist_Obj[outer]
                else:
                    hist_CObj[outer] = hist_CObj[outer - 1] + hist_Obj[outer]

                hist_hits[outer, :], hist_MRR[outer] = UltraE_test_hits_rank(
                    self.ttest, self.vec_entity, self.vec_relation, self.vec_bias, self.config_yaml)
                print('JJOBCD iter:{}, csumfval:{:.2f}, hist hits:{:.4f}-{:.4f}-{:.4f}, hist MRR:{:.4f}'.format(
                    outer, hist_CObj[outer][0].item(), hist_hits[outer, 0].item(), hist_hits[outer, 1].item(),
                    hist_hits[outer, 2].item(), hist_MRR[outer][0].item()))

        # The loop never exits early, so every row is populated. Histories are
        # kept by position: filtering on "!= 0" would drop legitimately zero
        # objectives and leave the returned series with mismatched lengths.
        hist_CObj = hist_CObj.flatten()
        return hist_CObj, self.vec_relation, hist_t, hist_hits, hist_MRR

    def _sample_coordinate_pairs(self):
        """Return an int64 ``(Rnum, d // 2, 2)`` tensor of random index pairs.

        For each relation, the indices ``0..d-1`` are shuffled and split into
        consecutive pairs, so every index appears exactly once per relation.
        """
        pairs = torch.zeros([self.Rnum, self.d // 2, 2], dtype=torch.int64, device=self.device)
        for i in range(self.Rnum):
            # The second shuffle is statistically redundant but is kept so that
            # seeded runs consume the same random numbers as before.
            shuffled = torch.randperm(self.d)
            shuffled = shuffled[torch.randperm(self.d)]
            pairs[i, :, :] = shuffled.view(-1, 2)
        return pairs

    def innertrain(self, Citer):
        """Update every relation matrix in place with ``Parallel_updateV``.

        Fresh index pairs are sampled into ``self.allB``; relation ``i`` is
        then replaced by the result of ``Parallel_updateV`` applied to its
        current value, its gradient ``self.Xgrad[i]`` and its index pairs.

        Args:
            Citer: Index of the current inner iteration (currently unused).
        """
        self.allB = self._sample_coordinate_pairs()

        # In-place writes to a leaf parameter are only legal with autograd disabled.
        with torch.no_grad():
            for i in range(self.Rnum):
                current = self.vec_relation[i, :, :].clone().detach().to(torch.float32).to(self.device)
                self.vec_relation[i, :, :] = Parallel_updateV(
                    current, self.Xgrad[i, :, :].to(torch.float32), self.allB[i, :, :],
                    self.Lconstant, self.theta, self.p)

    @staticmethod
    def _zero_nonfinite_grad(param):
        """Replace NaN and +/-Inf gradient entries of ``param`` with 0, in place.

        A parameter without a gradient (``param.grad is None``) is left alone.
        """
        grad = param.grad
        if grad is not None:
            grad.masked_fill_(~torch.isfinite(grad), 0)

    def fgrad(self):
        """Backpropagate the objective, step Adam and return the relation gradient.

        Non-finite gradient entries are zeroed for all three parameters, the
        relation gradient is norm-clipped to ``Lconstant``, and the optimizer
        (entity and bias only) takes one step. All gradients are cleared
        before returning.

        Returns:
            The sanitized, clipped gradient of the objective with respect to
            ``vec_relation``.
        """
        fval = UltraE_fval_obj(self.triple, self.Corrupted_triple, self.vec_entity, self.vec_relation,
                               self.vec_bias, self.config_yaml)
        fval.requires_grad_(True)
        fval.backward()

        params = (self.vec_entity, self.vec_bias, self.vec_relation)
        for param in params:
            self._zero_nonfinite_grad(param)
        clip_grad_norm_(self.vec_relation, max_norm=self.Lconstant)

        self.optimizer.step()
        # Grab the relation gradient before clearing; innertrain() consumes it.
        grad_R = self.vec_relation.grad
        for param in params:
            param.grad = None

        return grad_R


